/*
 * Javlov - a Java toolkit for reinforcement learning with multi-agent support.
 * 
 * Copyright (c) 2009 Matthijs Snel
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package net.javlov.example.rooms;

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.swing.Timer;

import net.javlov.*;
import net.javlov.policy.EGreedyPolicy;
import net.javlov.world.*;
import net.javlov.world.grid.*;
import net.javlov.world.phys2d.*;
import net.javlov.world.ui.GridWorldView;
import net.phys2d.raw.shapes.Circle;
import net.phys2d.raw.shapes.StaticBox;
import net.javlov.example.ExperimentGUI;
import net.javlov.example.GridLimitedOptionsWorld;

/**
 * Example experiment: a single learning agent in a 20x20 grid world that is
 * divided into four rooms by walls, with one-cell gaps (hallways) between
 * adjacent rooms. The agent chooses among eight {@link ReachHallOption}s, two
 * per room, which are built on top of the four primitive grid moves.
 *
 * <p>Call {@link #init()} once, then {@link #start()}. When {@link #gui} is
 * set, the experiment runs on a background thread next to a Swing view;
 * otherwise it runs on the calling thread.
 */
public class Main implements Runnable {
	
	/** The grid world; created by {@link #makeWorld()}. */
	GridLimitedOptionsWorld world;
	/** Reward function installed on the world and registered as its collision listener. */
	GridRewardFunction rf;
	/** Simulator that drives the episodes; created by {@link #init()}. */
	Simulator sim;
	/** Tabular Q-function shared by the agent and its policy; printed at the end of {@link #run()}. */
	TabularQFunction qf;
	/** Whether {@link #start()} opens the GUI and runs the experiment on a separate thread. */
	boolean gui;
	/** Size of one grid cell in world units; set by {@link #init()}. */
	int cellwidth, cellheight;
	
	/**
	 * Entry point: runs the example with the GUI enabled.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		Main m = new Main();
		m.gui = true;
		m.init();
		m.start();
	}
	
	/**
	 * Builds the world, the option pool, the agent with its body, and the
	 * simulator, and wires them together. Must be called before
	 * {@link #start()} or {@link #run()}.
	 */
	public void init() {
		cellwidth = 50; cellheight = cellwidth;

		// The world must exist before the options: the primitive moves are bound to it.
		makeWorld();		
		
		List<? extends Option> optionPool = makeOptions();
		world.setOptionPool(optionPool);
		
		Agent a = makeAgent(optionPool);
		AgentBody aBody = makeAgentBody();
		
		sim = new Simulator();
		sim.setEnvironment(world);
		
		world.add(a, aBody);
		sim.addAgent(a);
	}
	
	/**
	 * Creates the learning agent: a {@code QLearningAgent} over a tabular
	 * Q-function with one entry per option, a learn rate of 0.1, an
	 * {@code EGreedyPolicy} over the option pool, and SMDP mode switched off.
	 *
	 * @param optionPool the options the agent can select from
	 * @return the configured agent
	 */
	protected Agent makeAgent(List<? extends Option> optionPool) {
		qf = TabularQFunction.getInstance(optionPool.size());
		// NOTE(review): the meaning of the second constructor argument (1) is not
		// visible from this file -- confirm against QLearningAgent.
		SarsaAgent a = new QLearningAgent(qf, 1);
		a.setLearnRate(0.1);
		// NOTE(review): 0.1 is presumably the exploration rate (epsilon) -- confirm.
		Policy pi = new EGreedyPolicy(qf, 0.1, optionPool);
		a.setPolicy(pi);
		a.setSMDPMode(false);
		return a;
	}
	
	/**
	 * Creates the agent's body: a circle of radius 20 carrying a
	 * {@code GridGPSSensor} configured with the cell dimensions.
	 *
	 * @return the agent body with its sensor attached
	 */
	protected AgentBody makeAgentBody() {
		Phys2DAgentBody aBody = new Phys2DAgentBody(new Circle(20), 0.5f);
		GridGPSSensor gps = new GridGPSSensor(cellwidth, cellheight);
		gps.setBody(aBody);
		aBody.add(gps);
		return aBody;
	}
	
	/**
	 * Creates the option pool: two {@link ReachHallOption}s per room, all
	 * sharing one list of the four primitive grid moves. Each option's ID is
	 * its index in the returned list.
	 *
	 * <p>NOTE(review): judging by the option names, the four integers are the
	 * room's cell bounds (min x, max x, min y, max y), the second
	 * {@code Point} is the hallway the option heads for and the first is the
	 * room's other hallway -- confirm against {@code ReachHallOption}.
	 *
	 * @return the option pool, in ID order
	 */
	protected List<? extends Option> makeOptions() {
		List<Action> primitiveActions = new ArrayList<Action>();
		primitiveActions.add(GridMove.getNorthInstance(world));
		primitiveActions.add(GridMove.getEastInstance(world));
		primitiveActions.add(GridMove.getSouthInstance(world));
		primitiveActions.add(GridMove.getWestInstance(world));
		
		List<Option> optionPool = new ArrayList<Option>();
		
		// Room 1
		registerOption(optionPool,
				new ReachHallOption("R1H1", 0, 8, 0, 9, new Point(4,10), new Point(9,2), primitiveActions));
		registerOption(optionPool,
				new ReachHallOption("R1H4", 0, 8, 0, 9, new Point(9,2), new Point(4,10), primitiveActions));
		
		// Room 2
		registerOption(optionPool,
				new ReachHallOption("R2H1", 10, 19, 0, 7, new Point(16,8), new Point(9,2), primitiveActions));
		registerOption(optionPool,
				new ReachHallOption("R2H2", 10, 19, 0, 7, new Point(9,2), new Point(16,8), primitiveActions));
		
		// Room 3
		registerOption(optionPool,
				new ReachHallOption("R3H2", 10, 19, 9, 19, new Point(9,15), new Point(16,8), primitiveActions));
		registerOption(optionPool,
				new ReachHallOption("R3H3", 10, 19, 9, 19, new Point(16,8), new Point(9,15), primitiveActions));
		
		// Room 4
		registerOption(optionPool,
				new ReachHallOption("R4H3", 0, 8, 11, 19, new Point(4,10), new Point(9,15), primitiveActions));
		registerOption(optionPool,
				new ReachHallOption("R4H4", 0, 8, 11, 19, new Point(9,15), new Point(4,10), primitiveActions));
		
		return optionPool;
	}
	
	/** Assigns {@code option} the next free ID (the pool's current size) and appends it to the pool. */
	private void registerOption(List<Option> pool, Option option) {
		option.setID(pool.size());
		pool.add(option);
	}
	
	/**
	 * Creates the 20x20 grid world, adds the seven wall segments that divide
	 * it into rooms, installs the reward function (also registered as a
	 * collision listener) and adds the goal body.
	 *
	 * <p>Wall positions are given in cell units. Assuming a body's location
	 * is its centre (TODO confirm against {@code Phys2DBody}), the gaps
	 * between the segments fall on cells (4,10), (9,2), (9,15) and (16,8),
	 * which are the hallway points used in {@link #makeOptions()}.
	 */
	protected void makeWorld() {
		world = new GridLimitedOptionsWorld(20, 20, cellwidth, cellheight);
		
		// Horizontal wall on the left, in two segments.
		addWall(4, 1, 2, 10.5);
		addWall(4, 1, 7, 10.5);
		
		// Vertical wall down the middle, in three segments.
		addWall(1, 2, 9.5, 1);
		addWall(1, 12, 9.5, 9);
		addWall(1, 4, 9.5, 18);
		
		// Horizontal wall on the right, in two segments.
		addWall(6, 1, 13, 8.5);
		addWall(3, 1, 18.5, 8.5);
		
		rf = new GridRewardFunction();
		world.setRewardFunction(rf);
		world.addCollisionListener(rf);
		
		// Derived from the cell size so the goal scales with the walls: the
		// centre of cell (16,8), i.e. (825, 425) for the 50x50 cells set in init().
		GoalBody goal = new GoalBody(16*cellwidth + cellwidth/2, 8*cellheight + cellheight/2);
		goal.setReward(0);
		world.addFixedBody(goal);
	}
	
	/**
	 * Adds one obstacle box to the world; the size and location are given in cell units.
	 *
	 * @param widthCells  box width in cells
	 * @param heightCells box height in cells
	 * @param cellX       x location in cells, passed to {@code setLocation} after scaling
	 * @param cellY       y location in cells, passed to {@code setLocation} after scaling
	 */
	private void addWall(int widthCells, int heightCells, double cellX, double cellY) {
		Body wall = new Phys2DBody( new StaticBox(widthCells*cellwidth, heightCells*cellheight), 10, true );
		wall.setLocation(cellX*cellwidth, cellY*cellheight);
		wall.setType(Body.OBSTACLE);
		world.addFixedBody(wall);
	}
	
	/**
	 * Starts the experiment. Without the GUI it simply calls {@link #run()}
	 * on the calling thread. With the GUI it creates the world view, a Swing
	 * timer that fires the view at about 24 Hz, and the experiment window,
	 * then runs the experiment on a new thread.
	 */
	public void start() {
		if ( !gui ) {
			run();
			return;
		}
		
		GridWorldView wv = new GridWorldView(world);
		Timer timer = new Timer(1000/24, wv);
		// Constructed for its side effects only; the reference is not needed here.
		new ExperimentGUI("Rooms example", wv, sim);
		timer.start();
		new Thread(this).start();
	}

	/**
	 * Runs 5000 episodes with an episodic reward/step statistic attached,
	 * then prints the per-episode rewards and the Q-function to standard output.
	 */
	@Override
	public void run() {
		int episodes = 5000;
		EpisodicRewardStepStatistic stat = new EpisodicRewardStepStatistic(episodes);
		sim.addStatistic(stat);
		sim.init();
		// NOTE(review): presumably leaves the simulator paused until the GUI
		// resumes it; the effect in non-GUI mode is not visible here -- confirm
		// against Simulator.
		sim.suspend();
		sim.runEpisodes(episodes);
		System.out.println(Arrays.toString(stat.getRewards()));
		System.out.println(qf);
		
	}

}
